import numpy as np  # NumPy是一个用于科学计算的基础包,用于处理大型多维数组和矩阵
import faiss  # FAISS库用于高效的相似度搜索和稠密向量的聚类

# 定义FaissKNeighbors类，用于执行基于FAISS的K近邻搜索
class FaissKNeighbors:
    # 类初始化函数：初始化k值，FAISS资源对象res，以及用于存储数据的索引
    def __init__(self, k=1, res=None):
        self.index = None  # 用于存储训练数据的索引
        self.y = None  # 用于存储训练数据的标签
        self.k = k  # 最近邻个数
        self.res = res  # FAISS GPU资源对象

    # 训练函数：将训练数据加入到FAISS索引中
    def fit(self, X, y):
        # 初始化 self.index 为一个FAISS索引: IndexFlatL2, 该索引使用欧氏距离进行搜索
        self.index = faiss.IndexFlatL2(X.shape[1])  # L2距离 (欧氏距离)
        
        # 如果有GPU资源对象，则将索引转移到GPU上
        if self.res is not None:
            self.index = faiss.index_cpu_to_gpu(self.res, 0, self.index)
        
        # 添加数据到索引
        self.index.add(X.astype(np.float32))  # FAISS 处理的是float32类型的数据

        # 初始化 self.y 为传入的 y
        self.y = y

    # 预测函数：对新的数据集X进行分类预测
    def predict(self, X):
        # 搜索X中每个向量的k个最近邻
        distances, indices = self.index.search(X.astype(np.float32), self.k)
        votes = self.y[indices]  # 根据索引获得最近邻的标签
        # 通过投票机制得出最终预测的标签
        predictions = np.array([np.argmax(np.bincount(vote)) for vote in votes])
        return predictions

    # 评分函数：计算预测准确率
    def score(self, X, y):
        # 预测并比较预测结果和真实标签，计算准确率
        predictions = self.predict(X)
        accuracy = np.mean(predictions == y)  # 比较预测标签和真实标签，计算准确率
        return accuracy
